#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch
from typing import List
from ipex_llm.utils.common import invalidInputError


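# Merge several `torch.nn.Linear` layers that share the same input features into a
# single linear whose output is the concatenation of the original outputs.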
def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
    # GPTQ models store `qweight` instead of `weight`, so they are not merged here
    if hasattr(linears[0], "weight"):
        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
        if linears[0].bias is not None:
            new_linear = torch.nn.Linear(0, 0, bias=True)
            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        else:
            new_linear = torch.nn.Linear(0, 0, bias=False)
        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        new_linear.in_features = new_weight.size(1)
        new_linear.out_features = new_weight.size(0)
        return new_linear
    else:
        return None


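# Fuse an attention module's q_proj/k_proj/v_proj into a single `qkv_proj` linear.
# `attention_class` can be either the attention class itself or its class name.
# A sketch of typical usage during model conversion (names are illustrative only):
#   model.apply(lambda m: merge_qkv_base(m, "LlamaAttention"))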
def merge_qkv_base(module: torch.nn.Module, attention_class):
    if (
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
    ):
        qkv_proj = merge_linear([
            module.q_proj,
            module.k_proj,
            module.v_proj,
        ])
        if qkv_proj is not None:
            module.qkv_proj = qkv_proj
            del module.q_proj, module.k_proj, module.v_proj


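# Zero-pad the per-head output dimension of a q/k/v projection from
# `old_head_dim` to `new_head_dim`, keeping the input features unchanged.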
def padding_linear_hd(linear: torch.nn.Linear,
                      old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
    in_features, out_features = linear.in_features, linear.out_features

    weight = linear.weight.data
    weight = weight.view(-1, old_head_dim, in_features)
    new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
                             dtype=weight.dtype, device=weight.device)
    new_weight[:, :old_head_dim, :] = weight
    new_weight[:, old_head_dim:, :] = 0
    new_weight = new_weight.view(-1, in_features)
    if linear.bias is not None:
        bias = linear.bias.data
        bias = bias.view(-1, old_head_dim)
        new_bias = torch.empty([bias.size(0), new_head_dim],
                               dtype=bias.dtype, device=bias.device)
        new_bias[:, :old_head_dim] = bias
        new_bias[:, old_head_dim:] = 0
        new_bias = new_bias.flatten()

        new_linear = torch.nn.Linear(0, 0, bias=True)
        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    else:
        new_linear = torch.nn.Linear(0, 0, bias=False)
    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    new_linear.in_features = new_weight.size(1)
    new_linear.out_features = new_weight.size(0)
    return new_linear


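# Pad an attention module's q/k/v projections so that its head_dim becomes
# `new_head_dim`; the original value is recorded in `module.old_head_dim`
# (presumably so the extra dimensions can be sliced off the output afterwards).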
def padding_attention_hd_base(module: torch.nn.Module, attention_class,
                              old_head_dim: int, new_head_dim: int):
    if (
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
    ) and module.head_dim == old_head_dim:
        module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
        module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
        module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
        module.head_dim = new_head_dim
        module.old_head_dim = old_head_dim


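# For MLA-style attention (DeepSeek-like modules with `kv_b_proj`), zero-pad the
# value portion of `kv_b_proj` so that the value head dim matches the q/k head dim.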
def padding_mla_v_hd_base(module: torch.nn.Module, attention_class):
    if (
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
    ):
        k_head_dim = module.q_head_dim
        v_head_dim = module.v_head_dim
        if v_head_dim < k_head_dim:
            kv_b_proj = module.kv_b_proj
            w = kv_b_proj.weight.data.view(module.num_heads,
                                           module.qk_nope_head_dim + module.v_head_dim,
                                           module.kv_lora_rank)
            k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
            new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
                                  dtype=v_w.dtype, device=v_w.device)
            new_v_w[:, :v_head_dim, :] = v_w
            new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)

            new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
                                            dtype=new_w.dtype, device=new_w.device)
            new_kv_b_proj.in_features = new_w.size(1)
            new_kv_b_proj.out_features = new_w.size(0)
            new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)

            module.kv_b_proj = new_kv_b_proj


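# Runtime counterpart of the padding above: zero-pad a [bsz, heads, seq, head_dim]
# tensor along the last dimension when it still uses the old head dim.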
def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
    bsz, num_heads, seq_len, head_dim = states.size()
    if head_dim == old_head_dim and old_head_dim < new_head_dim:
        new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
                                 dtype=states.dtype, device=states.device)
        new_states[:, :, :, :old_head_dim] = states
        new_states[:, :, :, old_head_dim:] = 0
        return new_states
    return states


def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   old_head_dim: int, new_head_dim: int):
    return (
        padding_states_hd(q, old_head_dim, new_head_dim),
        padding_states_hd(k, old_head_dim, new_head_dim),
        padding_states_hd(v, old_head_dim, new_head_dim),
    )


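# Gated MLP forward: when the fused XPU kernel is applicable, compute gate_proj,
# up_proj and the activation in one `xe_linear.mlp_forward_xpu` call, otherwise
# fall back to the eager down_proj(act_fn(gate_proj(x)) * up_proj(x)).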
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import mlp_fusion_check
    qtype = getattr(module.gate_proj, "qtype", None)
    if mlp_fusion_check(x, qtype, module.training):
        import xe_linear
        x_2d = x.contiguous().view(-1, x.size(-1))
        output = module.down_proj(
            xe_linear.mlp_forward_xpu(
                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
                act, qtype
            )
        )
        return output.view(x.shape)
    else:
        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))


def mlp_silu_forward(self, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import SILU
    return fuse_mlp_base(self, SILU, x)


def mlp_gelu_forward(self, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import GELU
    return fuse_mlp_base(self, GELU, x)


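# Softmax over attention weights; uses the in-place xe_addons kernel for
# contiguous XPU tensors, otherwise a fp32 softmax cast back to the input dtype.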
def attention_softmax(attn_weights: torch.Tensor):
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
    else:
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                   dtype=torch.float32).to(attn_weights.dtype)
    return attn_weights


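# Generic RMSNorm forward that tolerates different epsilon attribute names
# (variance_epsilon / epsilon / eps) and uses the fused XPU kernel when possible.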
def rms_norm_forward(self, hidden_states: torch.Tensor):
    weight = self.weight
    if hasattr(self, "variance_epsilon"):
        eps = self.variance_epsilon
    elif hasattr(self, "epsilon"):
        eps = self.epsilon
    else:
        eps = self.eps

    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
        import xe_addons
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
        output = xe_addons.rms_norm(weight, x_2d, eps)
        return output.reshape(hidden_states.shape)
    else:
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return weight * hidden_states.to(input_dtype)


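# LayerNorm forward with a fused XPU kernel for fp32/fp16 inputs and the standard
# torch.nn.functional.layer_norm fallback elsewhere.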
def layer_norm_forward(self, hidden_states: torch.Tensor):
    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
        import xe_addons
        hidden_size = math.prod(self.normalized_shape)
        x_2d = hidden_states.reshape(-1, hidden_size).contiguous()
        output = xe_addons.layer_norm(x_2d, self.weight, self.bias, self.eps)
        return output.reshape(hidden_states.shape)
    else:
        return torch.nn.functional.layer_norm(
            hidden_states, self.normalized_shape,
            self.weight, self.bias, self.eps
        )


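# Build or pad the attention mask so that its kv length is rounded up to a
# multiple of 128 before calling the fused SDP kernels below; padded positions
# are filled with the dtype's minimum value.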
def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
    max_kvs = 128
    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
    if mask is None:
        if is_causal:
            mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
                              dtype=dtype, device=device)
            mask.triu_(1)
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
        elif seq_length != kv_length and seq_length <= 32:
            mask = None
        else:
            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
            mask[..., kv_length:padding_kv_length] = torch.finfo(dtype).min
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    else:
        if seq_length != kv_length and seq_length <= 32:
            mask = mask[..., :seq_length, :kv_length]
            mask = mask.expand([bsz, n_heads, seq_length, kv_length])
        elif mask.size(3) != padding_kv_length:
            new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
                                   dtype=dtype, device=device)
            new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
            new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
            new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
            mask.set_(new_mask)     # modify `mask` in place
        else:
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    return mask


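# SDPA dispatcher: on XPU with supported dtypes and head dims, call the fused
# xe_addons kernels (fp8 variants when the KV cache is stored as uint8); in all
# other cases repeat the KV heads if needed and use PyTorch's built-in SDPA.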
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
                                 value: torch.Tensor, mask: torch.Tensor = None,
                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
    bsz, n_heads, seq_length, head_dim = query.shape
    _, n_kv_heads, kv_length, _ = key.shape

    dtype, device = query.dtype, query.device

    if (
        device.type == "xpu"
        and dtype in [torch.float, torch.half]
        and head_dim in [64, 80, 96, 128, 192, 256]
    ):
        # prepare scale
        scale = 1 / math.sqrt(head_dim) if scale is None else scale

        # prepare mask
        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

        # compute
        import xe_addons
        if is_causal:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
            else:
                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
        elif seq_length != kv_length and seq_length <= 32:
            # todo: add further scale support
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
            else:
                attn_output = xe_addons.sdp(query, key, value, mask, scale)
        else:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
            else:
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

        return attn_output
    else:
        mask = mask[..., :seq_length, :kv_length] if mask is not None else None

        from ipex_llm.transformers.models.utils import repeat_kv
        if n_heads != n_kv_heads:
            key = repeat_kv(key, n_heads // n_kv_heads)
            value = repeat_kv(value, n_heads // n_kv_heads)

        if is_causal and mask is None:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, is_causal=is_causal, scale=scale
            )
        else:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, mask, scale=scale
            )
        attn_output = attn_output.to(dtype)    # workaround for an ipex 2.1 bug
        return attn_output


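# Forward pass of a low-bit quantized linear on XPU via xe_linear; the input is
# moved to the weight's device and cast to fp16 first.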
def linear_forward(x: torch.Tensor, weight: torch.Tensor, qtype: int, out_features: int):
    if weight.device.type == "xpu":
        new_shape = x.shape[:-1] + (out_features,)
        x = x.to(weight.device, dtype=torch.float16)
        x_2d = x.contiguous().view(-1, x.shape[-1])
        import xe_linear
        x = xe_linear.forward_new(x_2d, weight, qtype, out_features)
        x = x.view(new_shape)
        return x
    else:
        invalidInputError(False,
                          "Unsupported device type: only weights on the xpu device are supported.")


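# Quantize a floating-point weight tensor into ipex-llm's low-bit format for the
# given ggml precision string (e.g. "sym_int4"); returns the low-bit params and qtype.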
def quantize_linear(weight: torch.Tensor, in_features: int, precision: str):
    from ipex_llm.transformers.low_bit_linear import FP4Params
    from ipex_llm.ggml.quantize import ggml_tensor_qtype

    invalidInputError(precision in ggml_tensor_qtype.keys(),
                      f"{precision} is not supported, "
                      f"only {ggml_tensor_qtype.keys()} are supported now.")
    qtype = ggml_tensor_qtype[precision]
    paramsLowBit = FP4Params(data=weight.data,
                             requires_grad=False,
                             quantized=False,
                             _shape=None,
                             convert_shape_only=False,
                             qtype=qtype,
                             in_features=in_features,
                             enable_scale_search=False).to("cpu")
    return paramsLowBit, qtype


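# Grouped top-k expert selection (DeepSeek-style routing with a score correction
# bias) computed by a fused xe_addons kernel; returns expert indices and weights.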
def moe_group_topk(scores: torch.Tensor, e_score_correction_bias: torch.Tensor,
                   n_group: int, topk_group: int, top_k: int, norm_topk_prob: bool,
                   routed_scaling_factor: float):
    import xe_addons
    topk_idx, topk_weight = xe_addons.moe_group_topk(
        scores, e_score_correction_bias,
        n_group, 2, topk_group, top_k,
        top_k > 1 and norm_topk_prob, 1e-20, routed_scaling_factor
    )
    return topk_idx, topk_weight


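# Apply rotary position embedding to query and key states in place
# ("two" layout, with optional half_layout) via a fused xe_addons kernel.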
def rotary_two_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
                                   cos: torch.Tensor, sin: torch.Tensor,
                                   half_layout: bool):
    import xe_addons
    xe_addons.rotary_two_with_cache_inplaced(query_states, key_states,
                                             cos, sin, half_layout)


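# In-place rotary embedding for the "half" layout; the cos/sin caches are made
# contiguous in place before calling the fused kernel.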
def rotary_half_with_cache_inplaced(query_states: torch.Tensor, key_states: torch.Tensor,
                                    cos: torch.Tensor, sin: torch.Tensor):
    import xe_addons
    from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced
    make_cache_contiguous_inplaced(cos, sin)
    xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)


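# Softmax + top-k expert routing in one fused xe_addons call; returns the selected
# expert indices and the (optionally normalized) routing weights.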
def moe_softmax_topk(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool):
    import xe_addons
    selected_experts, routing_weights = xe_addons.moe_softmax_topk(
        router_logits, top_k, norm_topk_prob
    )
    return selected_experts, routing_weights


# q_proj, k_proj and v_proj should be ipex-llm quantized (low-bit) linear layers
def merge_quantized_qkv(q_proj, k_proj, v_proj, module):
    from ipex_llm.transformers.low_bit_linear import FP4Params
    from ipex_llm.ggml.quantize import ggml_tensor_qtype
    has_qtype = (hasattr(q_proj.weight, 'qtype')
                 and hasattr(k_proj.weight, 'qtype')
                 and hasattr(v_proj.weight, 'qtype'))
    invalidInputError((has_qtype
                       and q_proj.weight.qtype == k_proj.weight.qtype
                       and q_proj.weight.qtype == v_proj.weight.qtype
                       and q_proj.weight.qtype in ggml_tensor_qtype.values()),
                      f"{q_proj.weight.qtype} is not supported, "
                      f"only {ggml_tensor_qtype.values()} are supported now.")
    origin_device = q_proj.weight.device
    q_proj.weight = q_proj.weight.to('cpu')
    k_proj.weight = k_proj.weight.to('cpu')
    v_proj.weight = v_proj.weight.to('cpu')
    linears = [q_proj, k_proj, v_proj]
    new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
    if q_proj.has_bias:
        new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
        q_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    new_out_features = sum(layer.out_features for layer in linears)
    new_params_low_bit = FP4Params(data=new_weight.data,
                                   requires_grad=False,
                                   quantized=True,
                                   _shape=[new_out_features, q_proj.in_features],
                                   convert_shape_only=False,
                                   qtype=q_proj.weight.qtype,
                                   in_features=q_proj.in_features,
                                   enable_scale_search=False)
    q_proj.out_features = new_out_features
    q_proj.weight = new_params_low_bit.to(origin_device)
    # reuse q_proj (now holding the merged weight) as the fused qkv_proj
    module.qkv_proj = module.q_proj
    del module.k_proj, module.v_proj, module.q_proj
